{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rX8mhOLljYeM"
   },
   "source": [
    "##### Copyright 2019 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:31.254688Z",
     "iopub.status.busy": "2020-09-22T22:30:31.253706Z",
     "iopub.status.idle": "2020-09-22T22:30:31.256163Z",
     "shell.execute_reply": "2020-09-22T22:30:31.256582Z"
    },
    "id": "BZSlp3DAjdYf"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3wF5wszaj97Y"
   },
   "source": [
    "# TensorFlow 2.0 quickstart for experts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DUNzJc4jTj6G"
   },
   "source": [
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/quickstart/advanced\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />View on tensorflow.google.cn</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/quickstart/advanced.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/quickstart/advanced.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/tutorials/quickstart/advanced.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />Download notebook</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XBr8cc2zy4NQ"
   },
   "source": [
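    "Note: These documents were translated by the TensorFlow community. Because community translations are best-effort, there is no guarantee that they are accurate or that they reflect the latest\n",
    "[official English documentation](https://tensorflow.google.cn/?hl=en). If you have suggestions for improving this translation, please submit a pull request to the\n",
    "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the\n",
    "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)."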
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hiH7AC-NTniF"
   },
   "source": [
    "This is a [Google Colaboratory](https://colab.research.google.com/notebooks/welcome.ipynb) notebook file. Python programs run directly in the browser, which is a great way to learn and use TensorFlow. To follow this tutorial, click the button at the top of this page to run the notebook in Google Colab.\n",
    "\n",
    "1. In Colab, connect to a Python runtime: at the top right of the menu bar, select *CONNECT*.\n",
    "2. Run all the notebook code cells: select *Runtime* > *Run all*."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eOsVdx6GGHmU"
   },
   "source": [
    "Download and install the TensorFlow 2.0 Beta package if it is not already available in your environment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QS7DDTiZGRTo"
   },
   "source": [
    "Import TensorFlow into your program:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:31.261849Z",
     "iopub.status.busy": "2020-09-22T22:30:31.261089Z",
     "iopub.status.idle": "2020-09-22T22:30:37.740516Z",
     "shell.execute_reply": "2020-09-22T22:30:37.739884Z"
    },
    "id": "0trJmd6DjqBZ"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras.layers import Dense, Flatten, Conv2D\n",
    "from tensorflow.keras import Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7NAbSZiaoJ4z"
   },
   "source": [
    "Load and prepare the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:37.747120Z",
     "iopub.status.busy": "2020-09-22T22:30:37.746114Z",
     "iopub.status.idle": "2020-09-22T22:30:38.221445Z",
     "shell.execute_reply": "2020-09-22T22:30:38.220840Z"
    },
    "id": "JqFRS6K07jJs"
   },
   "outputs": [],
   "source": [
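    "mnist = tf.keras.datasets.mnist\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "# Scale pixel values from [0, 255] to [0.0, 1.0]; the division yields float64 arrays.\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "\n",
    "# Add a channels dimension: (N, 28, 28) -> (N, 28, 28, 1), as Conv2D expects.\n",
    "x_train = x_train[..., tf.newaxis]\n",
    "x_test = x_test[..., tf.newaxis]"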
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "k1Evqx0S22r_"
   },
   "source": [
    "Use `tf.data` to shuffle the dataset and split it into batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:40.166281Z",
     "iopub.status.busy": "2020-09-22T22:30:39.168850Z",
     "iopub.status.idle": "2020-09-22T22:30:40.222916Z",
     "shell.execute_reply": "2020-09-22T22:30:40.223387Z"
    },
    "id": "8Iu_quO024c2"
   },
   "outputs": [],
   "source": [
    "train_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_train, y_train)).shuffle(10000).batch(32)\n",
    "test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)"
   ]
  },
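  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give a feel for what `shuffle(10000).batch(32)` does conceptually, here is a minimal pure-Python sketch, assuming a buffer-based shuffle followed by fixed-size grouping that keeps the final partial batch. The helper names `buffered_shuffle` and `batched` are hypothetical and are not part of `tf.data`; the real implementation differs in detail.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def buffered_shuffle(items, buffer_size, rng):\n",
    "    # Hypothetical helper: keep up to buffer_size elements in memory and\n",
    "    # emit a randomly chosen one each time a new element arrives.\n",
    "    buffer = []\n",
    "    for item in items:\n",
    "        if len(buffer) < buffer_size:\n",
    "            buffer.append(item)\n",
    "            continue\n",
    "        index = rng.randrange(buffer_size)\n",
    "        yield buffer[index]\n",
    "        buffer[index] = item\n",
    "    # Input exhausted: emit whatever is left in the buffer in random order.\n",
    "    rng.shuffle(buffer)\n",
    "    yield from buffer\n",
    "\n",
    "\n",
    "def batched(items, batch_size):\n",
    "    # Hypothetical helper: group consecutive elements into lists of\n",
    "    # batch_size; the last batch may be smaller.\n",
    "    batch = []\n",
    "    for item in items:\n",
    "        batch.append(item)\n",
    "        if len(batch) == batch_size:\n",
    "            yield batch\n",
    "            batch = []\n",
    "    if batch:\n",
    "        yield batch\n",
    "\n",
    "\n",
    "rng = random.Random(0)\n",
    "batches = list(batched(buffered_shuffle(range(10), 4, rng), 3))\n",
    "print([len(b) for b in batches])\n",
    "```\n",
    "\n",
    "Because the buffer of 10,000 elements is smaller than the 60,000 MNIST training examples, the shuffle in this notebook is approximate rather than a uniform permutation of the whole dataset."
   ]
  },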
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BPZ68wASog_I"
   },
   "source": [
    "Build the `tf.keras` model using the Keras [model subclassing API](https://tensorflow.google.cn/guide/keras#model_subclassing):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:40.231066Z",
     "iopub.status.busy": "2020-09-22T22:30:40.229947Z",
     "iopub.status.idle": "2020-09-22T22:30:40.559012Z",
     "shell.execute_reply": "2020-09-22T22:30:40.558275Z"
    },
    "id": "h3IKyzTCDNGo"
   },
   "outputs": [],
   "source": [
    "class MyModel(Model):\n",
    "  def __init__(self):\n",
    "    super().__init__()\n",
    "    # 32 convolution filters with a 3x3 kernel\n",
    "    self.conv1 = Conv2D(32, 3, activation='relu')\n",
    "    self.flatten = Flatten()\n",
    "    self.d1 = Dense(128, activation='relu')\n",
    "    # One output probability per digit class (0-9)\n",
    "    self.d2 = Dense(10, activation='softmax')\n",
    "\n",
    "  def call(self, x):\n",
    "    x = self.conv1(x)\n",
    "    x = self.flatten(x)\n",
    "    x = self.d1(x)\n",
    "    return self.d2(x)\n",
    "\n",
    "# Create an instance of the model\n",
    "model = MyModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uGih-c2LgbJu"
   },
   "source": [
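    "Choose an optimizer and a loss function for training. Because the model's last layer applies softmax, the loss is used with its default `from_logits=False`:"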
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:40.564042Z",
     "iopub.status.busy": "2020-09-22T22:30:40.563092Z",
     "iopub.status.idle": "2020-09-22T22:30:40.565286Z",
     "shell.execute_reply": "2020-09-22T22:30:40.565709Z"
    },
    "id": "u48C9WQ774n4"
   },
   "outputs": [],
   "source": [
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam()"
   ]
  },
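  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the loss concrete, the following pure-Python sketch shows what sparse categorical cross-entropy computes when the model outputs probabilities: the mean over the batch of the negative log-probability assigned to the true class. The helper name `sparse_cce` and the sample values are hypothetical, and the sketch omits the numerical clipping that Keras applies to avoid `log(0)`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sparse_cce(labels, probabilities):\n",
    "    # labels: integer class ids, one per example.\n",
    "    # probabilities: one list of class probabilities per example.\n",
    "    per_example = [-math.log(p[y]) for y, p in zip(labels, probabilities)]\n",
    "    # Average the per-example losses over the batch.\n",
    "    return sum(per_example) / len(per_example)\n",
    "\n",
    "\n",
    "# Illustrative batch of two examples with three classes each.\n",
    "batch_probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]\n",
    "batch_labels = [0, 1]\n",
    "loss = sparse_cce(batch_labels, batch_probs)\n",
    "print(loss)\n",
    "```\n",
    "\n",
    "The labels are plain integers rather than one-hot vectors, which is why the notebook uses the sparse variant of the loss with the integer MNIST labels."
   ]
  },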
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JB6A1vcigsIe"
   },
   "source": [
    "Select metrics to measure the loss and the accuracy of the model. These metrics accumulate values over each epoch, and the overall result is then printed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:40.577927Z",
     "iopub.status.busy": "2020-09-22T22:30:40.576655Z",
     "iopub.status.idle": "2020-09-22T22:30:40.603126Z",
     "shell.execute_reply": "2020-09-22T22:30:40.603574Z"
    },
    "id": "N0MqHFb4F_qn"
   },
   "outputs": [],
   "source": [
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')"
   ]
  },
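  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough mental model of how a metric such as `tf.keras.metrics.Mean` accumulates values, the sketch below keeps a running total and a count, reports their ratio, and can be reset between epochs. The class `RunningMean` and the sample loss values are hypothetical illustrations, not the Keras implementation; returning 0.0 before any update is an assumption made here for simplicity.\n",
    "\n",
    "```python\n",
    "class RunningMean:\n",
    "    # Hypothetical stand-in for a streaming mean metric.\n",
    "\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        # Forget everything accumulated so far.\n",
    "        self.total = 0.0\n",
    "        self.count = 0\n",
    "\n",
    "    def update(self, value):\n",
    "        # Fold one more observation into the running state.\n",
    "        self.total += value\n",
    "        self.count += 1\n",
    "\n",
    "    def result(self):\n",
    "        # Mean of all observations since the last reset.\n",
    "        return self.total / self.count if self.count else 0.0\n",
    "\n",
    "\n",
    "metric = RunningMean()\n",
    "for batch_loss in [0.9, 0.5, 0.4]:  # illustrative per-batch losses\n",
    "    metric.update(batch_loss)\n",
    "epoch_mean = metric.result()\n",
    "metric.reset()  # start the next epoch from a clean state\n",
    "print(epoch_mean)\n",
    "```\n",
    "\n",
    "Resetting at the start of each epoch keeps the reported numbers from mixing batches from different epochs."
   ]
  },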
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ix4mEL65on-w"
   },
   "source": [
    "Use `tf.GradientTape` to train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:40.609855Z",
     "iopub.status.busy": "2020-09-22T22:30:40.608884Z",
     "iopub.status.idle": "2020-09-22T22:30:40.611490Z",
     "shell.execute_reply": "2020-09-22T22:30:40.610897Z"
    },
    "id": "OZACiVqA8KQV"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(images, labels):\n",
    "  # Record the forward pass so gradients can be computed from it\n",
    "  with tf.GradientTape() as tape:\n",
    "    predictions = model(images)\n",
    "    loss = loss_object(labels, predictions)\n",
    "  gradients = tape.gradient(loss, model.trainable_variables)\n",
    "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "  # Update the running training metrics with this batch\n",
    "  train_loss(loss)\n",
    "  train_accuracy(labels, predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z8YT7UmFgpjV"
   },
   "source": [
    "Test the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:40.616527Z",
     "iopub.status.busy": "2020-09-22T22:30:40.615611Z",
     "iopub.status.idle": "2020-09-22T22:30:40.617749Z",
     "shell.execute_reply": "2020-09-22T22:30:40.618293Z"
    },
    "id": "xIKdEzHAJGt7"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def test_step(images, labels):\n",
    "  predictions = model(images)\n",
    "  t_loss = loss_object(labels, predictions)\n",
    "\n",
    "  test_loss(t_loss)\n",
    "  test_accuracy(labels, predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-22T22:30:40.624999Z",
     "iopub.status.busy": "2020-09-22T22:30:40.623951Z",
     "iopub.status.idle": "2020-09-22T22:30:59.406637Z",
     "shell.execute_reply": "2020-09-22T22:30:59.406064Z"
    },
    "id": "i-2pkctU_Ci7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Layer my_model is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n",
      "Epoch 1, Loss: 0.13994966447353363, Accuracy: 95.9000015258789, Test Loss: 0.06561796367168427, Test Accuracy: 97.88999938964844\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "  # Reset the metrics at the start of the next epoch\n",
    "  train_loss.reset_states()\n",
    "  train_accuracy.reset_states()\n",
    "  test_loss.reset_states()\n",
    "  test_accuracy.reset_states()\n",
    "\n",
    "  for images, labels in train_ds:\n",
    "    train_step(images, labels)\n",
    "\n",
    "  for test_images, test_labels in test_ds:\n",
    "    test_step(test_images, test_labels)\n",
    "\n",
    "  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "  print(template.format(epoch+1,\n",
    "                         train_loss.result(),\n",
    "                         train_accuracy.result()*100,\n",
    "                         test_loss.result(),\n",
    "                         test_accuracy.result()*100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T4JfEh7kvx6m"
   },
   "source": [
    "The image classifier is now trained to an accuracy of nearly 98% on this dataset. To learn more, read the [TensorFlow tutorials](https://tensorflow.google.cn/tutorials/keras)."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "advanced.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
